{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import pickle\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# BERT model predictions. The frame must be bound to `bert_df`: every later cell\n",
    "# reads it under that name, so binding it to `bert` fails on a fresh kernel.\n",
    "bert_df = pd.read_csv('../data/bert/bert.csv')\n",
    "\n",
    "# Second-level stacking predictions.\n",
    "# Replace AddNN-Ensemble-L0.868963-NB50000.csv with the file newly generated under\n",
    "# second_level/preds, or look the name up programmatically with:\n",
    "# file_name = [v for v in os.listdir('../data/ensemble/second_level/preds/') if v != '.gitkeep'][0]\n",
    "sec_stacking_df = pd.read_csv('../data/ensemble/second_level/preds/AddNN-Ensemble-L0.868963-NB50000.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "agreed 0.9441773531582828\n",
      "disagreed 0.8592555850590337\n",
      "unrelated 0.9301244635454418\n"
     ]
    }
   ],
   "source": [
    "# Per-column correlation between the BERT and second-level stacking predictions.\n",
    "for label in bert_df.columns:\n",
    "    bert_vs_stacking = bert_df[label].corr(sec_stacking_df[label])\n",
    "    print(label, bert_vs_stacking)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Hard label per row for the stacking predictions: the column holding the largest value.\n",
    "sec_pred_labels = sec_stacking_df.idxmax(axis=1)\n",
    "sec_sub = pd.DataFrame(data={\"Category\": sec_pred_labels})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Hard label per row for the BERT predictions: the column holding the largest value.\n",
    "bert_pred_labels = bert_df.idxmax(axis=1)\n",
    "bert_sub = pd.DataFrame(data={\"Category\": bert_pred_labels})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Blend weights for the two prediction frames (0.42 + 0.58 = 1).\n",
    "BERT_WEIGHT = 0.42\n",
    "SEC_STACKING_WEIGHT = 0.58\n",
    "\n",
    "# Weighted element-wise average of the BERT and second-level stacking predictions.\n",
    "bagging = BERT_WEIGHT * bert_df + SEC_STACKING_WEIGHT * sec_stacking_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 122,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "agreed 0.986968273495625\n",
      "disagreed 0.9775277504308344\n",
      "unrelated 0.9833833321969182\n"
     ]
    }
   ],
   "source": [
    "# Per-column correlation between the stacking predictions and the weighted blend.\n",
    "for label in bert_df.columns:\n",
    "    stacking_vs_blend = sec_stacking_df[label].corr(bagging[label])\n",
    "    print(label, stacking_vs_blend)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 123,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "agreed 0.9848845315832347\n",
      "disagreed 0.9477836516321912\n",
      "unrelated 0.9813389063686446\n"
     ]
    }
   ],
   "source": [
    "# Per-column correlation between the BERT predictions and the weighted blend.\n",
    "for label in bert_df.columns:\n",
    "    bert_vs_blend = bert_df[label].corr(bagging[label])\n",
    "    print(label, bert_vs_blend)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 124,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Pair each test-set id with the blend's top-scoring column for that row.\n",
    "test_df = pd.read_csv(\"../data/dataset/test.csv\")\n",
    "blended_labels = bagging.idxmax(axis=1)\n",
    "\n",
    "# Key order is kept as Id, Category so the submission's column order is unchanged.\n",
    "sub = pd.DataFrame({\"Id\": test_df[\"id\"].values, \"Category\": blended_labels})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "unrelated    0.654220\n",
       "agreed       0.326186\n",
       "disagreed    0.019594\n",
       "Name: Category, dtype: float64"
      ]
     },
     "execution_count": 125,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fraction of submission rows assigned to each category.\n",
    "category_counts = sub[\"Category\"].value_counts()\n",
    "category_counts.div(len(sub))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 126,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3862"
      ]
     },
     "execution_count": 126,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of rows where the stacking-only label differs from the blended label.\n",
    "sec_vs_blend_mismatch = sec_sub[\"Category\"] != sub[\"Category\"]\n",
    "sec_vs_blend_mismatch.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 127,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2316"
      ]
     },
     "execution_count": 127,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of rows where the BERT-only label differs from the blended label.\n",
    "bert_vs_blend_mismatch = bert_sub[\"Category\"] != sub[\"Category\"]\n",
    "bert_vs_blend_mismatch.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 120,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Write the blended hard-label submission, omitting the index column.\n",
    "em_path = '../data/ensemble/second_level/EM.csv'\n",
    "sub.to_csv(em_path, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Write the blended per-column scores (not the hard labels), omitting the index column.\n",
    "pseudo_labels_path = '../data/ensemble/second_level/FirstLevelPseudoLabels.csv'\n",
    "bagging.to_csv(pseudo_labels_path, index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
